{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c8d01073-09a8-4c68-b93b-aef463738bd0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:38:27.890197Z",
     "iopub.status.busy": "2022-06-14T17:38:27.890007Z",
     "iopub.status.idle": "2022-06-14T17:38:29.221575Z",
     "shell.execute_reply": "2022-06-14T17:38:29.220839Z",
     "shell.execute_reply.started": "2022-06-14T17:38:27.890178Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Populating the interactive namespace from numpy and matplotlib\n"
     ]
    }
   ],
   "source": [
    "import os, sys, glob, argparse\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "%pylab inline\n",
    "\n",
    "import cv2\n",
    "from PIL import Image\n",
    "from sklearn.model_selection import train_test_split, StratifiedKFold, KFold\n",
    "\n",
    "import torch\n",
    "torch.manual_seed(0)\n",
    "torch.backends.cudnn.deterministic = False\n",
    "torch.backends.cudnn.benchmark = True\n",
    "\n",
    "import torchvision.models as models\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data.dataset import Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "431a9196-30ec-4f86-8da7-cd6e2edc7154",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:41:25.725907Z",
     "iopub.status.busy": "2022-06-14T17:41:25.725277Z",
     "iopub.status.idle": "2022-06-14T17:41:25.737498Z",
     "shell.execute_reply": "2022-06-14T17:41:25.736974Z",
     "shell.execute_reply.started": "2022-06-14T17:41:25.725857Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load the training index: a space-separated table with `file` and `label` columns.\n",
    "train_df = pd.read_csv('./LED_part1_train/train.csv', sep=' ')\n",
    "# File ids are zero-padded to six characters to form the .bmp file names.\n",
    "train_df['file'] = train_df['file'].apply(lambda x: str(x).zfill(6) + '.bmp')\n",
    "# Shuffle with a fixed seed: the last 500 rows are later held out for validation,\n",
    "# so an unseeded shuffle would change that split on every kernel restart.\n",
    "train_df = train_df.sample(frac=1.0, random_state=0)\n",
    "\n",
    "train_path = './LED_part1_train/imgs/' + train_df['file']\n",
    "train_label = train_df['label']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "8cf398db-d606-493f-b942-8e8d717e1b4a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:48:47.774884Z",
     "iopub.status.busy": "2022-06-14T17:48:47.774390Z",
     "iopub.status.idle": "2022-06-14T17:48:47.783309Z",
     "shell.execute_reply": "2022-06-14T17:48:47.782663Z",
     "shell.execute_reply.started": "2022-06-14T17:48:47.774843Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Read the submission template and derive every test image path from its `file` id:\n",
    "# the part after the first underscore, zero-padded to six characters, plus '.bmp'.\n",
    "test_df = pd.read_csv('提交示例.csv')\n",
    "test_df['path'] = [\n",
    "    name.split('_')[1].zfill(6) + '.bmp'\n",
    "    for name in test_df['file']\n",
    "]\n",
    "test_path = './LED_part1_train/imgs/' + test_df['path']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "3e604253-e8f0-4fd4-bafd-aa0ffe8ddcd7",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:52:10.777464Z",
     "iopub.status.busy": "2022-06-14T17:52:10.776927Z",
     "iopub.status.idle": "2022-06-14T17:52:10.785118Z",
     "shell.execute_reply": "2022-06-14T17:52:10.784073Z",
     "shell.execute_reply.started": "2022-06-14T17:52:10.777421Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "set()"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Leakage check: image paths that appear in both the test and the training set.\n",
    "# Compare paths with paths; intersecting with the label column cannot reveal overlap.\n",
    "set(test_path) & set(train_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "584f2188-cf45-4d72-abee-4dae0d362926",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:40:45.725403Z",
     "iopub.status.busy": "2022-06-14T17:40:45.724822Z",
     "iopub.status.idle": "2022-06-14T17:40:45.734679Z",
     "shell.execute_reply": "2022-06-14T17:40:45.733977Z",
     "shell.execute_reply.started": "2022-06-14T17:40:45.725355Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class XunFeiDataset(Dataset):\n",
    "    \"\"\"Dataset that reads images with OpenCV and pairs each one with a label.\n",
    "\n",
    "    Args:\n",
    "        img_path: indexable collection of image file paths.\n",
    "        img_label: indexable collection of labels, aligned with ``img_path``.\n",
    "        transform: optional callable invoked as ``transform(image=img)`` whose\n",
    "            result is indexed with ``'image'`` (the albumentations convention).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, img_path, img_label, transform=None):\n",
    "        self.img_path = img_path\n",
    "        self.img_label = img_label\n",
    "        self.transform = transform\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return ``(image, label)`` for the sample at ``index``.\n",
    "\n",
    "        The image is a float32 array whose axes are reordered with\n",
    "        ``transpose([2, 0, 1])`` (channels first); the label is a tensor built\n",
    "        from ``img_label[index]``.\n",
    "\n",
    "        Raises:\n",
    "            FileNotFoundError: if OpenCV cannot read the image file.\n",
    "        \"\"\"\n",
    "        path = self.img_path[index]\n",
    "        img = cv2.imread(path)\n",
    "        # cv2.imread reports a missing or unreadable file by returning None instead\n",
    "        # of raising; fail here with the offending path rather than with a later\n",
    "        # AttributeError on None.\n",
    "        if img is None:\n",
    "            raise FileNotFoundError('Could not read image: {}'.format(path))\n",
    "        img = img.astype(np.float32)\n",
    "\n",
    "        # Divide by 255 and shift by -1 (8-bit input therefore ends up in [-1, 0]).\n",
    "        img /= 255.0\n",
    "        img -= 1\n",
    "\n",
    "        if self.transform is not None:\n",
    "            img = self.transform(image=img)['image']\n",
    "        img = img.transpose([2, 0, 1])\n",
    "        return img, torch.from_numpy(np.array(self.img_label[index]))\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of samples (the length of ``img_path``).\"\"\"\n",
    "        return len(self.img_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "bd5a22b6-3a32-4039-bcc7-a6c6f5f9cbaf",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:40:49.844310Z",
     "iopub.status.busy": "2022-06-14T17:40:49.843731Z",
     "iopub.status.idle": "2022-06-14T17:40:49.852994Z",
     "shell.execute_reply": "2022-06-14T17:40:49.851873Z",
     "shell.execute_reply.started": "2022-06-14T17:40:49.844261Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class XunFeiNet(nn.Module):\n",
    "    \"\"\"ResNet-18 backbone whose final layer is replaced by a 4-output linear head.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        backbone = models.resnet18(pretrained=True)\n",
    "        # Adaptive pooling to a 1x1 map, followed by a fresh 512 -> 4 classifier.\n",
    "        backbone.avgpool = nn.AdaptiveAvgPool2d(1)\n",
    "        backbone.fc = nn.Linear(512, 4)\n",
    "        self.resnet = backbone\n",
    "\n",
    "    def forward(self, img):\n",
    "        \"\"\"Run ``img`` through the backbone and return its output.\"\"\"\n",
    "        return self.resnet(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "47539bef-14a7-4881-85e9-0882ff295e6a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:40:51.726767Z",
     "iopub.status.busy": "2022-06-14T17:40:51.726181Z",
     "iopub.status.idle": "2022-06-14T17:40:51.738406Z",
     "shell.execute_reply": "2022-06-14T17:40:51.737735Z",
     "shell.execute_reply.started": "2022-06-14T17:40:51.726718Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def train(train_loader, model, criterion, optimizer):\n",
    "    \"\"\"Train ``model`` for one pass over ``train_loader``.\n",
    "\n",
    "    Each batch is moved to the GPU, run forward and backward, and followed by an\n",
    "    optimizer step. The batch loss is printed on every 20th batch.\n",
    "\n",
    "    Returns:\n",
    "        The mean of the per-batch loss values.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    running_loss = 0.0\n",
    "    for step, (images, labels) in enumerate(train_loader):\n",
    "        images = images.cuda(non_blocking=True)\n",
    "        labels = labels.cuda(non_blocking=True)\n",
    "\n",
    "        # forward pass\n",
    "        loss = criterion(model(images), labels)\n",
    "\n",
    "        # backward pass and parameter update\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if step % 20 == 0:\n",
    "            print(loss.item())\n",
    "\n",
    "        running_loss += loss.item()\n",
    "\n",
    "    return running_loss / len(train_loader)\n",
    "            \n",
    "def validate(val_loader, model, criterion):\n",
    "    \"\"\"Compute the classification accuracy of ``model`` on ``val_loader``.\n",
    "\n",
    "    Args:\n",
    "        val_loader: iterable of ``(input, target)`` batches that exposes ``.dataset``.\n",
    "        model: network returning scores with the class axis at dimension 1.\n",
    "        criterion: unused; kept so existing callers keep working.\n",
    "\n",
    "    Returns:\n",
    "        Number of samples whose arg-max prediction equals the target, divided by\n",
    "        ``len(val_loader.dataset)``.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for images, labels in val_loader:\n",
    "            images = images.cuda()\n",
    "            labels = labels.cuda()\n",
    "\n",
    "            predictions = model(images).argmax(1)\n",
    "            correct += (predictions == labels).sum().item()\n",
    "\n",
    "    return correct / len(val_loader.dataset)\n",
    "\n",
    "def predict(test_loader, model, criterion):\n",
    "    \"\"\"Run ``model`` over ``test_loader`` and collect its raw outputs.\n",
    "\n",
    "    Args:\n",
    "        test_loader: iterable of ``(input, target)`` batches; targets are ignored.\n",
    "        model: network to evaluate.\n",
    "        criterion: unused; kept so existing callers keep working.\n",
    "\n",
    "    Returns:\n",
    "        A numpy array with the per-batch outputs stacked along the first axis.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "\n",
    "    batch_outputs = []\n",
    "    with torch.no_grad():\n",
    "        for images, _ in test_loader:\n",
    "            images = images.cuda()\n",
    "            batch_outputs.append(model(images).detach().cpu().numpy())\n",
    "\n",
    "    return np.vstack(batch_outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "c14b410f-de94-4f52-be5b-1377c34bda4a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:48:57.772099Z",
     "iopub.status.busy": "2022-06-14T17:48:57.771517Z",
     "iopub.status.idle": "2022-06-14T17:48:57.782443Z",
     "shell.execute_reply": "2022-06-14T17:48:57.781799Z",
     "shell.execute_reply.started": "2022-06-14T17:48:57.772050Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import albumentations as A\n",
    "\n",
    "# Training augmentation: random 130x450 crop plus flip, contrast and brightness jitter.\n",
    "train_transform = A.Compose([\n",
    "    A.RandomCrop(130, 450),\n",
    "    A.HorizontalFlip(p=0.5),\n",
    "    A.RandomContrast(p=0.5),\n",
    "    A.RandomBrightnessContrast(p=0.5),\n",
    "])\n",
    "# Validation and test images only get the random crop.\n",
    "val_transform = A.Compose([A.RandomCrop(130, 450)])\n",
    "test_transform = A.Compose([A.RandomCrop(130, 450)])\n",
    "\n",
    "# The last 500 rows of the training table are held out for validation.\n",
    "train_dataset = XunFeiDataset(\n",
    "    train_path.values[:-500], train_label.values[:-500], train_transform\n",
    ")\n",
    "val_dataset = XunFeiDataset(\n",
    "    train_path.values[-500:], train_label.values[-500:], val_transform\n",
    ")\n",
    "# Test images have no labels, so a placeholder 0 is supplied for each of them.\n",
    "test_dataset = XunFeiDataset(test_path, [0] * len(test_path), test_transform)\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    train_dataset, batch_size=10, shuffle=True, num_workers=1, pin_memory=False\n",
    ")\n",
    "val_loader = torch.utils.data.DataLoader(\n",
    "    val_dataset, batch_size=2, shuffle=False, num_workers=1, pin_memory=False\n",
    ")\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    test_dataset, batch_size=2, shuffle=False, num_workers=1, pin_memory=False\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "7f7ee8cb-cfc0-4eb8-96ea-7302144ae7a4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:45:26.788800Z",
     "iopub.status.busy": "2022-06-14T17:45:26.788206Z",
     "iopub.status.idle": "2022-06-14T17:45:27.016315Z",
     "shell.execute_reply": "2022-06-14T17:45:27.015873Z",
     "shell.execute_reply.started": "2022-06-14T17:45:26.788740Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Build the network on the GPU together with its loss function and optimizer.\n",
    "model = XunFeiNet().to('cuda')\n",
    "criterion = nn.CrossEntropyLoss().cuda()\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "b2faad78-e794-4f31-8b56-1a64dbf89ccc",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:45:28.727590Z",
     "iopub.status.busy": "2022-06-14T17:45:28.727006Z",
     "iopub.status.idle": "2022-06-14T17:46:02.609564Z",
     "shell.execute_reply": "2022-06-14T17:46:02.608928Z",
     "shell.execute_reply.started": "2022-06-14T17:45:28.727541Z"
    },
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.1781413555145264\n",
      "1.0419951677322388\n",
      "1.0792276859283447\n",
      "0.30832964181900024\n",
      "0.8367191553115845\n",
      "1.3866803646087646\n",
      "1.5160852670669556\n",
      "1.6955101490020752\n",
      "0.7909185290336609\n",
      "0.27210813760757446\n",
      "0.1143868938088417\n",
      "0.45553913712501526\n",
      "0.5953348278999329\n",
      "0.45054784417152405\n",
      "0.33823224902153015\n",
      "0.7402018904685974\n",
      "0.3970138430595398\n",
      "0.6925562620162964\n",
      "1.2593271732330322\n",
      "0.721782386302948\n",
      "0.8106012187607876 0.882\n",
      "0.862005352973938\n",
      "0.7307425737380981\n",
      "0.5787581205368042\n",
      "0.5375865697860718\n",
      "0.962967038154602\n",
      "0.49011683464050293\n",
      "0.13052770495414734\n",
      "1.891213059425354\n",
      "0.8977954983711243\n",
      "0.5088403820991516\n",
      "0.6014495491981506\n",
      "1.0218145847320557\n",
      "0.5123348832130432\n",
      "0.5912672281265259\n",
      "1.1143639087677002\n",
      "0.24420467019081116\n",
      "1.0526559352874756\n",
      "0.4181271195411682\n",
      "0.12599065899848938\n",
      "1.1317453384399414\n",
      "0.6991669833079561 0.828\n",
      "1.0638439655303955\n",
      "0.9436470866203308\n",
      "0.2173367291688919\n",
      "0.715546727180481\n",
      "0.6696562767028809\n",
      "0.5196341276168823\n",
      "0.4729689061641693\n",
      "1.4259179830551147\n",
      "0.7534366250038147\n",
      "1.1594340801239014\n",
      "0.41260257363319397\n",
      "0.496182382106781\n",
      "1.1747467517852783\n",
      "0.5072125792503357\n",
      "0.7038043737411499\n",
      "1.145532250404358\n",
      "0.38535618782043457\n",
      "0.911472499370575\n",
      "0.44795355200767517\n",
      "0.9146186709403992\n",
      "0.6921569254104193 0.9\n"
     ]
    }
   ],
   "source": [
    "# Three passes over the training data; after each one, print the mean training\n",
    "# loss followed by the validation accuracy.\n",
    "for epoch in range(3):\n",
    "    train_loss = train(train_loader, model, criterion, optimizer)\n",
    "    val_acc = validate(val_loader, model, criterion)\n",
    "    print(train_loss, val_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "6787d12a-b923-4777-abf5-4d90904551d1",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:49:00.776785Z",
     "iopub.status.busy": "2022-06-14T17:49:00.776211Z",
     "iopub.status.idle": "2022-06-14T17:49:17.437779Z",
     "shell.execute_reply": "2022-06-14T17:49:17.436995Z",
     "shell.execute_reply.started": "2022-06-14T17:49:00.776738Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Sum the raw model outputs of 10 passes over the test set; the test transform\n",
    "# crops at random, so every pass sees different crops.\n",
    "pred = predict(test_loader, model, criterion)\n",
    "for _ in range(9):\n",
    "    pred += predict(test_loader, model, criterion)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "1743990f-d27d-43cd-85fd-23a0ff0f2c36",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:49:24.773508Z",
     "iopub.status.busy": "2022-06-14T17:49:24.772895Z",
     "iopub.status.idle": "2022-06-14T17:49:24.781353Z",
     "shell.execute_reply": "2022-06-14T17:49:24.780697Z",
     "shell.execute_reply.started": "2022-06-14T17:49:24.773431Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Predicted class per test image, keyed by file name (the last path component).\n",
    "image_names = [path.rsplit('/', 1)[-1] for path in test_path]\n",
    "submit = pd.DataFrame({'image': image_names, 'label': pred.argmax(1)})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "cc565971-7a5f-4bb0-8ab5-c33281bc469c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:49:59.776648Z",
     "iopub.status.busy": "2022-06-14T17:49:59.776064Z",
     "iopub.status.idle": "2022-06-14T17:49:59.790585Z",
     "shell.execute_reply": "2022-06-14T17:49:59.789949Z",
     "shell.execute_reply.started": "2022-06-14T17:49:59.776602Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Re-read the submission template, fill in the predicted class and save it\n",
    "# without the index column.\n",
    "test_df = pd.read_csv('提交示例.csv')\n",
    "test_df['label'] = pred.argmax(1)\n",
    "\n",
    "test_df.to_csv('submit2.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "129ff32f-8444-4cf2-9997-0101960ea490",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.6 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
